library(dplyr)
library(tidyr)
library(Rcpp)
library(abind)

### NOTE: deliberately no `rm(list = ls())` here. It wipes the caller's
### workspace when the script is sourced and still does not give a clean
### session (attached packages and options survive). Run the script in a
### fresh R process instead.

### load the functions needed to compute negative log likelihood
Rcpp::sourceCpp("src/misc.cpp")

### subdirectories
in_dir <- "./data"
outdir <- "./output"

### load data
# the outer parentheses print the names of the restored objects
(loaded <- load(file.path(in_dir, "analysis_data.RData")))
# the vote choice samples built below read these two objects
stopifnot(all(c("vdata", "pdata") %in% loaded))
estimates <- readRDS(file.path(outdir, "rum_estimates.rds"))

### seed the RNG only after load(): an .RData image can carry its own
### `.Random.seed`, which would silently replace a seed set before loading
set.seed(123)
  
### constants
# number of cross-validation folds
nfolds <- 10

# ISO2 codes, named by themselves so that lapply() over them returns a list
# keyed by country code
countries <- setNames(nm = c("CZ", "DE", "ES", "FR", "HU", "IT", "NL", "SE", "GB"))

# display labels keyed by ISO2 code; GB is relabelled and the vector is then
# sorted by label
countries.lab <- setNames(
  countrycode::countrycode(countries, "iso2c", "country.name"),
  countries
)
countries.lab[["GB"]] <- "England"
countries.lab <- sort(countries.lab)

# respondent covariates that get interacted with the party dummies
ucov <- c("female", "university", "age_under30", "age_over50")
# position variables and their display labels (presumably in matching order)
dims <- c("econ", "cult", "populism", "ant", "ppl", "man")
dims.lab <- c(
  "Economic Left-Right",
  "Cultural Conservatism",
  "Populism",
  "Anti-elitism",
  "People-centrism",
  "Manichean outlook"
)
# coordinates of the 2-dimensional and the 3-dimensional models
mdims <- c("econ", "cult")
sdims <- c("econ", "cult", "populism")
popvars <- c("populism", "ant", "ppl", "man")

## prepare sample pieces for vote choice models
## `sam[[cn]]` is a list of `nfolds` elements, each holding a `train` and a
## `val` list with the same fields:
##   n          number of respondents in the piece
##   wt         respondent weights
##   Q0.M/P0.M  party / respondent coordinates on `mdims`
##   Q.M/P.M    party / respondent coordinates on `sdims`
##   X0         party dummies, array [party, column, respondent]
##   X1         X0 plus dummy-by-covariate interactions along the 2nd axis
##   gt         0-based index of the chosen party within the menu
sam <- lapply(countries, function(cn) {
  v <- vdata[[cn]] |> tibble::rowid_to_column("rowid")
  p <- subset(pdata, cntry == cn)
  # choice set: parties named by more than 10 respondents that also appear
  # in the party data
  menu <- v |>
    count(cntry_party_ivc) |>
    filter(n > 10) |>
    pull(cntry_party_ivc) |>
    intersect(p$cntry_party) |>
    sort()
  v <- v |>
    filter(cntry_party_ivc %in% menu) |>
    select(rowid, wt, cntry_party_ivc, any_of(dims), any_of(ucov)) |>
    mutate(wt = wt / mean(wt))
  p <- subset(p, cntry_party %in% menu, select = c("cntry_party", dims))

  # form the folds: ranks of uniform draws are a random permutation of the
  # rows; taken modulo nfolds they give fold ids 0..(nfolds - 1)
  fold <- (v |> nrow() |> runif() |> row_number()) %% nfolds

  # party-level intercepts and control variables
  # Exactly one party -- the most frequently chosen -- is the reference
  # category. with_ties = FALSE matters: with the default, two equally large
  # parties would both be dropped and lose their intercepts.
  reference <- v |>
    count(cntry_party_ivc) |>
    slice_max(order_by = n, n = 1L, with_ties = FALSE) |>
    pull(cntry_party_ivc)
  tobin <- setdiff(menu, reference)
  dummies <- outer(menu, tobin, function(x, y) as.numeric(x == y))
  # same dummy matrix for every respondent: [party, dummy, respondent]
  xval1 <- xval0 <- replicate(nrow(v), dummies)
  # append one block of dummy-by-covariate columns per covariate, in the
  # order of `ucov`
  for (uv in ucov) {
    xval1 <- abind(xval1, outer(dummies, v[[uv]]), along = 2L)
  }

  # coordinates in the political space
  Q.df <- as.matrix(p[match(menu, p$cntry_party), sdims])
  P.df <- as.matrix(v[, sdims])

  # compile pieces into a list
  wt <- v$wt / mean(v$wt)
  y <- match(v$cntry_party_ivc, menu)

  # Build one train or validation piece from a logical row selector.
  # drop = FALSE keeps matrices / 3-d arrays intact when a piece has a single
  # respondent or the menu leaves a single dummy column.
  fold_piece <- function(keep) {
    P_sub <- P.df[keep, , drop = FALSE]
    list(
      n = nrow(P_sub),
      wt = wt[keep],
      Q0.M = Q.df[, mdims, drop = FALSE],
      P0.M = P_sub[, mdims, drop = FALSE],
      Q.M = Q.df[, sdims, drop = FALSE],
      P.M = P_sub[, sdims, drop = FALSE],
      X0 = xval0[, , keep, drop = FALSE],
      X1 = xval1[, , keep, drop = FALSE],
      gt = y[keep] - 1L
    )
  }

  lapply(seq_len(nfolds) - 1L, function(f) {
    list(
      train = fold_piece(fold != f),
      val = fold_piece(fold == f)
    )
  })
})

# Build a negative log-likelihood function whose formals carry the starting
# values (this is how stats4::mle() picks them up further down the script).
#   separable  if TRUE lambda has one entry per column of env$Q; otherwise
#              ndim * (ndim - 1) / 2 zeros are appended
#   discount   only used when `controls` is TRUE: adds `alpha` (start 1) and
#              switches to rcpp_nll_xd110
#   controls   if TRUE the function also takes `beta`, one zero per entry of
#              the 2nd dimension of env$X
#   env        named list (Q, P, gt, wt and, with controls, X) copied into the
#              returned function's enclosing environment
make_nll <- function(separable=TRUE, discount=TRUE, controls=FALSE, env) {
  k <- ncol(env[["Q"]])

  lambda_start <- rep(1, k)
  if (!separable) {
    lambda_start <- c(lambda_start, rep(0, (k * (k - 1L)) %/% 2L))
  }

  if (!controls) {
    f <- function(lambda) rcpp_nll_base(lambda, P, Q, gt, wt)
  } else {
    beta_start <- rep(0, dim(env[["X"]])[2L])
    if (discount) {
      f <- function(alpha, lambda, beta) rcpp_nll_xd110(lambda, P, Q, gt, wt, beta, X, alpha)
      formals(f)[["alpha"]] <- 1
    } else {
      f <- function(lambda, beta) rcpp_nll_x(lambda, P, Q, gt, wt, beta, X)
    }
    formals(f)[["beta"]] <- beta_start
  }
  formals(f)[["lambda"]] <- lambda_start

  # make P, Q, gt, wt (and X) visible to the body of f
  list2env(env, environment(f))
  f
}

### models with 3 dimensions
### A1 - separable
### A2 - nonseparable
### A3 - nonseparable + intercepts
### A4 - nonseparable + controls
### A5 - nonseparable + controls + discounting
## C are 2-dimensional alternatives

# One entry per model, in output order.
#   coord  "0" selects the 2-dimensional Q0/P0 matrices, "" the 3-dimensional
#          Q/P matrices
#   X      design array: X0 = intercepts only, X1 = intercepts and controls
#   separable / discount / controls are passed on to make_nll()
model_specs <- list(
  # 2-dimensional
  C1 = list(coord = "0", X = "X0", separable = TRUE,  discount = FALSE, controls = FALSE),
  C2 = list(coord = "0", X = "X0", separable = FALSE, discount = FALSE, controls = FALSE),
  C3 = list(coord = "0", X = "X0", separable = FALSE, discount = FALSE, controls = TRUE),
  C4 = list(coord = "0", X = "X1", separable = FALSE, discount = FALSE, controls = TRUE),
  C5 = list(coord = "0", X = "X1", separable = FALSE, discount = TRUE,  controls = TRUE),
  # 3-dimensional
  A1 = list(coord = "",  X = "X0", separable = TRUE,  discount = FALSE, controls = FALSE),
  A2 = list(coord = "",  X = "X0", separable = FALSE, discount = FALSE, controls = FALSE),
  A3 = list(coord = "",  X = "X0", separable = FALSE, discount = FALSE, controls = TRUE),
  A4 = list(coord = "",  X = "X1", separable = FALSE, discount = FALSE, controls = TRUE),
  A5 = list(coord = "",  X = "X1", separable = FALSE, discount = TRUE,  controls = TRUE)
)

# Assemble the data list (Q, P, gt, wt, X) of one model from a train or
# validation piece of `sam`.
model_data <- function(piece, spec, tag) {
  c(
    list(
      Q = piece[[paste0("Q", spec[["coord"]], ".", tag)]],
      P = piece[[paste0("P", spec[["coord"]], ".", tag)]]
    ),
    piece[c("gt", "wt")],
    list(X = piece[[spec[["X"]]]])
  )
}

# Split a named coefficient vector into alpha / beta / lambda, each ordered by
# the numeric suffix of the coefficient name (so lambda10 follows lambda9).
# A parameter the model does not have is returned as 0.
split_coef <- function(co) {
  prefixes <- c(alpha = "alpha", beta = "beta", lambda = "lambda")
  lapply(prefixes, function(prefix) {
    nm <- grep(paste0("^", prefix), names(co), value = TRUE)
    if (length(nm) == 0L) {
      return(0)
    }
    hits <- regmatches(nm, gregexpr("\\d+\\.?\\d*", nm))
    idx <- vapply(
      hits,
      function(m) if (length(m) == 0L) 0L else as.integer(m[[1L]]),
      integer(1)
    )
    co[nm[order(idx)]]
  })
}

# Call the predictor that matches the likelihood make_nll() picks for `spec`.
predict_model <- function(spec, par, dat) {
  if (!spec[["controls"]]) {
    rcpp_pred_base(par[["lambda"]], dat[["P"]], dat[["Q"]])
  } else if (spec[["discount"]]) {
    rcpp_pred_xd110(par[["lambda"]], dat[["P"]], dat[["Q"]],
                    par[["beta"]], dat[["X"]], par[["alpha"]])
  } else {
    rcpp_pred_x(par[["lambda"]], dat[["P"]], dat[["Q"]],
                par[["beta"]], dat[["X"]])
  }
}

# xvalpred[[country]][[fold]][[model]] is a data frame with one row per
# validation respondent: pr (prediction), gt, fold, vb, wt.
# NOTE(review): pr is later compared with gt, so the predictors presumably
# return a 0-based party index per respondent -- confirm in src/misc.cpp.
xvalpred <- list()
for (cn in countries) {
  xvalpred[[cn]] <- list()
  for (f in seq_along(sam[[cn]])) {
    xvalpred[[cn]][[f]] <- list()
    s <- sam[[cn]][[f]]
    n_train <- s[["train"]][["n"]]
    for (tag in c("M")) {
      for (mod in names(model_specs)) {
        spec <- model_specs[[mod]]
        where <- sprintf("%s, fold %d, model %s", cn, f, mod)

        nll <- make_nll(
          separable = spec[["separable"]],
          discount = spec[["discount"]],
          controls = spec[["controls"]],
          env = model_data(s[["train"]], spec, tag)
        )
        co <- tryCatch(
          stats4::mle(minuslogl = nll, nobs = n_train,
                      control = list(maxit = 250))@coef,
          error = function(e) {
            warning("ML fit failed (", where, "): ", conditionMessage(e),
                    call. = FALSE)
            NULL
          }
        )

        # A failed fit must yield NA predictions (dropped downstream), not
        # predictions computed from all-zero placeholder parameters.
        pr <- if (is.null(co)) {
          NA
        } else {
          tryCatch(
            predict_model(spec, split_coef(co),
                          model_data(s[["val"]], spec, tag)),
            error = function(e) {
              warning("prediction failed (", where, "): ",
                      conditionMessage(e), call. = FALSE)
              NA
            }
          )
        }

        xvalpred[[cn]][[f]][[mod]] <- data.frame(
          pr = as.vector(pr),
          gt = s[["val"]][["gt"]],
          fold = f,
          vb = tag,
          wt = s[["val"]][["wt"]]
        )
      }
    }
  }
}

# Stack all per-country / per-fold / per-model prediction frames into one
# long data frame, tagging each row with its country, fold and model and with
# `match` = 1 when the prediction equals the observed choice. Rows without a
# prediction (failed models) are dropped.
# Folds and models are taken from the element being processed rather than
# from a hard-coded 1:10 or from the first country.
pile <- lapply(names(xvalpred), function(cn) {
  lapply(seq_along(xvalpred[[cn]]), function(f) {
    lapply(names(xvalpred[[cn]][[f]]), function(mod) {
      xvalpred[[cn]][[f]][[mod]] |>
        mutate(
          cntry = cn,
          fold = f,
          model = mod,
          match = as.numeric(pr == gt)
        )
    }) |> bind_rows()
  }) |> bind_rows()
}) |>
  bind_rows() |>
  filter(!is.na(pr))

## accuracy
# Weighted share of matching predictions per country and model, laid out in
# long form: one row per (country, ndim, column) with a formatted percentage.
# The first character of the model name gives the dimensionality (A / C), the
# rest gives the table column.
acc <- pile |>
  group_by(cntry, model) |>
  summarize(accuracy = sum(match * wt) / sum(wt)) |>
  ungroup() |>
  transmute(
    country = countries.lab[cntry],
    ndim = case_match(
      substring(model, 1L, 1L),
      "A" ~ "3-dim.",
      "C" ~ "2-dim."
    ),
    stat = "Acc.",
    column = paste0("c", substring(model, 2L)),
    val = paste0(formatC(accuracy * 100, digits = 1, format = "f"), "%")
  )

# AIC of the stored fits in `estimates`, in the same long layout as `acc`.
# The element names are parsed by position: character 1 is the model family
# (A / C), character 2 the table column, characters 4-5 the ISO2 country code.
# vapply() guarantees exactly one numeric AIC per fit (sapply() would silently
# return a list for an empty or irregular input).
aics <- data.frame(
  nam = names(estimates),
  AIC = vapply(estimates, stats4::AIC, numeric(1))
) |>
  mutate(
    column = paste0("c", substring(nam, 2L, 2L)),
    ndim = case_match(
      substring(nam, 1L, 1L),
      "A" ~ "3-dim.",
      "C" ~ "2-dim."
    ),
    country = countries.lab[substring(nam, 4L, 5L)],
    stat = "AIC",
    val = formatC(AIC, digits = 1, format = "f")
  ) |>
  select(country, ndim, stat, column, val)
# Combine accuracy and AIC rows, spread the model columns (c1, c2, ...) and
# sort so that each country forms a contiguous group of rows.
tab <- bind_rows(acc, aics) |>
  pivot_wider(names_from = column, values_from = val) |>
  arrange(country, ndim, stat) |>
  mutate(
    # Print each label only on the first row of its group. Using duplicated()
    # instead of fixed row positions keeps the layout right even if a country
    # lacks one of its rows. ndim is blanked first because its grouping still
    # needs the intact country column.
    ndim = ifelse(duplicated(paste(country, ndim)), "", ndim),
    country = ifelse(duplicated(country), "", country)
  )
# `type` is an argument of print(), not of xtable(), where it was ignored
print(
  xtable::xtable(tab),
  type = "latex",
  file = file.path(outdir, "table2.tex"),
  include.rownames = FALSE,
  floating = FALSE,
  tabular.environment = "longtable"
)


